{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TD3.ipynb",
      "version": "0.3.2",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "eFY4nKwHXR7x",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.autograd as autograd\n",
        "import torch.optim as optim\n",
        "\n",
        "import gym\n",
        "import random\n",
        "import numpy as np\n",
        "from collections import deque"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TVXINdceXXVt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def mini_batch_train(env, agent, max_episodes, max_steps, batch_size):\n",
        "    episode_rewards = []\n",
        "\n",
        "    for episode in range(max_episodes):\n",
        "        state = env.reset()\n",
        "        episode_reward = 0\n",
        "        \n",
        "        for step in range(max_steps):\n",
        "            action = agent.get_action(state)\n",
        "            next_state, reward, done, _ = env.step(action)\n",
        "            agent.replay_buffer.push(state, action, reward, next_state, done)\n",
        "            episode_reward += reward\n",
        "\n",
        "            if len(agent.replay_buffer) > batch_size:\n",
        "                agent.update(batch_size)   \n",
        "\n",
        "            if done or step == max_steps-1:\n",
        "                episode_rewards.append(episode_reward)\n",
        "                print(\"Episode \" + str(episode) + \": \" + str(episode_reward))\n",
        "                break\n",
        "\n",
        "            state = next_state"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Qex2Y_nAXsIN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BasicBuffer:\n",
        "\n",
        "  def __init__(self, max_size):\n",
        "      self.max_size = max_size\n",
        "      self.buffer = deque(maxlen=max_size)\n",
        "\n",
        "  def push(self, state, action, reward, next_state, done):\n",
        "      experience = (state, action, np.array([reward]), next_state, done)\n",
        "      self.buffer.append(experience)\n",
        "\n",
        "  def sample(self, batch_size):\n",
        "      state_batch = []\n",
        "      action_batch = []\n",
        "      reward_batch = []\n",
        "      next_state_batch = []\n",
        "      done_batch = []\n",
        "\n",
        "      batch = random.sample(self.buffer, batch_size)\n",
        "\n",
        "      for experience in batch:\n",
        "          state, action, reward, next_state, done = experience\n",
        "          state_batch.append(state)\n",
        "          action_batch.append(action)\n",
        "          reward_batch.append(reward)\n",
        "          next_state_batch.append(next_state)\n",
        "          done_batch.append(done)\n",
        "\n",
        "      return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)\n",
        "\n",
        "  def __len__(self):\n",
        "      return len(self.buffer)"
      ],
      "execution_count": 0,
      "outputs": []
    },
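    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# A quick sanity check for BasicBuffer (illustrative only, not part of TD3):\n",
        "# push a handful of random Pendulum-sized transitions and confirm that\n",
        "# sampling returns five aligned batches of the requested size.\n",
        "_buf = BasicBuffer(max_size=100)\n",
        "for _ in range(10):\n",
        "    s = np.random.randn(3)                   # Pendulum-v0 observations are 3-dimensional\n",
        "    a = np.random.uniform(-2, 2, size=(1,))  # and actions are 1-dimensional\n",
        "    _buf.push(s, a, np.random.randn(), np.random.randn(3), False)\n",
        "\n",
        "states, actions, rewards, next_states, dones = _buf.sample(4)\n",
        "assert len(states) == len(actions) == len(rewards) == len(dones) == 4\n",
        "print(\"buffer size:\", len(_buf))"
      ],
      "execution_count": 0,
      "outputs": []
    },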
    {
      "cell_type": "code",
      "metadata": {
        "id": "rPpQ_mNqXXft",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Critic(nn.Module):\n",
        "\n",
        "    def __init__(self, obs_dim, action_dim):\n",
        "        super(Critic, self).__init__()\n",
        "\n",
        "        self.obs_dim = obs_dim\n",
        "        self.action_dim = action_dim\n",
        "\n",
        "        self.linear1 = nn.Linear(self.obs_dim, 1024)\n",
        "        self.linear2 = nn.Linear(1024 + self.action_dim, 512)\n",
        "        self.linear3 = nn.Linear(512, 300)\n",
        "        self.linear4 = nn.Linear(300, 1)\n",
        "\n",
        "    def forward(self, x, a):\n",
        "        x = F.relu(self.linear1(x))\n",
        "        xa_cat = torch.cat([x,a], 1)\n",
        "        xa = F.relu(self.linear2(xa_cat))\n",
        "        xa = F.relu(self.linear3(xa))\n",
        "        qval = self.linear4(xa)\n",
        "\n",
        "        return qval\n",
        "\n",
        "class Actor(nn.Module):\n",
        "\n",
        "    def __init__(self, obs_dim, action_dim):\n",
        "        super(Actor, self).__init__()\n",
        "\n",
        "        self.obs_dim = obs_dim\n",
        "        self.action_dim = action_dim\n",
        "\n",
        "        self.linear1 = nn.Linear(self.obs_dim, 512)\n",
        "        self.linear2 = nn.Linear(512, 128)\n",
        "        self.linear3 = nn.Linear(128, self.action_dim)\n",
        "\n",
        "    def forward(self, obs):\n",
        "        x = F.relu(self.linear1(obs))\n",
        "        x = F.relu(self.linear2(x))\n",
        "        x = torch.tanh(self.linear3(x))\n",
        "\n",
        "        return x"
      ],
      "execution_count": 0,
      "outputs": []
    },
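    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Quick shape check for the networks above (a minimal sketch, assuming\n",
        "# Pendulum-v0 dimensions: obs_dim=3, action_dim=1). The critic maps a\n",
        "# (state, action) batch to one Q-value per row; the actor maps states to\n",
        "# tanh-squashed actions in [-1, 1].\n",
        "_actor = Actor(obs_dim=3, action_dim=1)\n",
        "_critic = Critic(obs_dim=3, action_dim=1)\n",
        "\n",
        "_s = torch.randn(8, 3)   # batch of 8 states\n",
        "_a = _actor(_s)          # -> (8, 1)\n",
        "_q = _critic(_s, _a)     # -> (8, 1)\n",
        "print(_a.shape, _q.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },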
    {
      "cell_type": "code",
      "metadata": {
        "id": "f0mCWDBUXXii",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class TD3Agent:\n",
        "\n",
        "    def __init__(self, env, gamma, tau, buffer_maxlen, delay_step, noise_std, noise_bound, critic_lr, actor_lr):\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        \n",
        "        self.env = env\n",
        "        self.obs_dim = env.observation_space.shape[0]\n",
        "        self.action_dim = env.action_space.shape[0]\n",
        "\n",
        "        # hyperparameters    \n",
        "        self.gamma = gamma\n",
        "        self.tau = tau\n",
        "        self.noise_std = noise_std\n",
        "        self.noise_bound = noise_bound\n",
        "        self.update_step = 0 \n",
        "        self.delay_step = delay_step\n",
        "        \n",
        "        # initialize actor and critic networks\n",
        "        self.critic1 = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.critic2 = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.critic1_target = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.critic2_target = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        \n",
        "        self.actor = Actor(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.actor_target = Actor(self.obs_dim, self.action_dim).to(self.device)\n",
        "    \n",
        "        # Copy critic target parameters\n",
        "        for target_param, param in zip(self.critic1_target.parameters(), self.critic1.parameters()):\n",
        "            target_param.data.copy_(param.data)\n",
        "\n",
        "        for target_param, param in zip(self.critic2_target.parameters(), self.critic2.parameters()):\n",
        "            target_param.data.copy_(param.data)\n",
        "\n",
        "        # initialize optimizers        \n",
        "        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr)\n",
        "        self.critic2_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr) \n",
        "        self.actor_optimizer  = optim.Adam(self.actor.parameters(), lr=actor_lr)\n",
        "        \n",
        "        self.replay_buffer = BasicBuffer(buffer_maxlen)        \n",
        "\n",
        "    def get_action(self, obs):\n",
        "        state = torch.FloatTensor(obs).unsqueeze(0).to(self.device)\n",
        "        action = self.actor.forward(state)\n",
        "        action = action.squeeze(0).cpu().detach().numpy()\n",
        "\n",
        "        return action\n",
        "    \n",
        "    def update(self, batch_size):\n",
        "        state_batch, action_batch, reward_batch, next_state_batch, masks = self.replay_buffer.sample(batch_size)\n",
        "        state_batch = torch.FloatTensor(state_batch).to(self.device)\n",
        "        action_batch = torch.FloatTensor(action_batch).to(self.device)\n",
        "        reward_batch = torch.FloatTensor(reward_batch).to(self.device)\n",
        "        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)\n",
        "        masks = torch.FloatTensor(masks).to(self.device)\n",
        "        \n",
        "        action_space_noise = self.generate_action_space_noise(action_batch)\n",
        "        next_actions = self.actor.forward(state_batch) + action_space_noise\n",
        "        next_Q1 = self.critic1_target.forward(next_state_batch, next_actions)\n",
        "        next_Q2 = self.critic2_target.forward(next_state_batch, next_actions)\n",
        "        expected_Q = reward_batch + self.gamma * torch.min(next_Q1, next_Q2)\n",
        "\n",
        "        # critic loss\n",
        "        curr_Q1 = self.critic1.forward(state_batch, action_batch)\n",
        "        curr_Q2 = self.critic2.forward(state_batch, action_batch)\n",
        "        critic1_loss = F.mse_loss(curr_Q1, expected_Q.detach())\n",
        "        critic2_loss = F.mse_loss(curr_Q2, expected_Q.detach())\n",
        "        \n",
        "        # update critics\n",
        "        self.critic1_optimizer.zero_grad()\n",
        "        critic1_loss.backward()\n",
        "        self.critic1_optimizer.step()\n",
        "\n",
        "        self.critic2_optimizer.zero_grad()\n",
        "        critic2_loss.backward()\n",
        "        self.critic2_optimizer.step()\n",
        "\n",
        "        # delyaed update for actor & target networks  \n",
        "        if(self.update_step % self.delay_step == 0):\n",
        "            # actor\n",
        "            self.actor_optimizer.zero_grad()\n",
        "            policy_gradient = -self.critic1(state_batch, self.actor(state_batch)).mean()\n",
        "            policy_gradient.backward()\n",
        "            self.actor_optimizer.step()\n",
        "\n",
        "            # target networks\n",
        "            self.update_targets()\n",
        "\n",
        "        self.update_step += 1\n",
        "\n",
        "    def generate_action_space_noise(self, action_batch):\n",
        "        noise = torch.normal(torch.zeros(action_batch.size()), self.noise_std).clamp(-self.noise_bound, self.noise_bound).to(self.device)\n",
        "        return noise\n",
        "\n",
        "    def update_targets(self):\n",
        "        for target_param, param in zip(self.critic1_target.parameters(), self.critic1.parameters()):\n",
        "            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))\n",
        "\n",
        "        for target_param, param in zip(self.critic2_target.parameters(), self.critic2.parameters()):\n",
        "            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))\n",
        "        \n",
        "        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):\n",
        "            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))"
      ],
      "execution_count": 0,
      "outputs": []
    },
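    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# A hedged helper (not part of the original agent): roll out the current\n",
        "# policy to gauge progress. It relies only on agent.get_action and the\n",
        "# standard gym step API, and averages the return over a few episodes.\n",
        "def evaluate(agent, env, episodes=5, max_steps=500):\n",
        "    returns = []\n",
        "    for _ in range(episodes):\n",
        "        state = env.reset()\n",
        "        total_reward = 0.0\n",
        "        for _ in range(max_steps):\n",
        "            action = agent.get_action(state)\n",
        "            state, reward, done, _ = env.step(action)\n",
        "            total_reward += reward\n",
        "            if done:\n",
        "                break\n",
        "        returns.append(total_reward)\n",
        "    return np.mean(returns)"
      ],
      "execution_count": 0,
      "outputs": []
    },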
    {
      "cell_type": "code",
      "metadata": {
        "id": "sn9QzzHoX3qs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "env = gym.make(\"Pendulum-v0\")\n",
        "gamma = 0.99\n",
        "tau = 1e-2\n",
        "noise_std = 0.2\n",
        "bound = 0.5\n",
        "delay_step = 2\n",
        "buffer_maxlen = 100000\n",
        "critic_lr = 1e-3\n",
        "actor_lr = 1e-3\n",
        "\n",
        "max_episodes = 100\n",
        "max_steps = 500\n",
        "batch_size = 32\n",
        "\n",
        "agent = TD3Agent(env, gamma, tau, buffer_maxlen, delay_step, noise_std, bound, critic_lr, actor_lr)\n",
        "episode_rewards = mini_batch_train(env, agent, 50, 500, 64)"
      ],
      "execution_count": 0,
      "outputs": []
    },
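    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional diagnostics (not in the original notebook): plot the learning\n",
        "# curve returned by mini_batch_train and report an average deterministic\n",
        "# return using the evaluate() helper sketched above.\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "plt.plot(episode_rewards)\n",
        "plt.xlabel(\"episode\")\n",
        "plt.ylabel(\"episode reward\")\n",
        "plt.title(\"TD3 on Pendulum-v0\")\n",
        "plt.show()\n",
        "\n",
        "print(\"average evaluation return:\", evaluate(agent, env))"
      ],
      "execution_count": 0,
      "outputs": []
    }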
  ]
}